# -*-Python-*-

# Talking-heads version of t5.1.1.base.gin
# https://arxiv.org/abs/2003.02436
# 24 talking heads with d_kv=32

include 'models/t5.1.1.base.gin'

# Encoder layer list (binding applies only within the `encoder/` scope):
# talking-heads self-attention, then DenseReluDense, in that order.
# Presumably this overrides the layer list set by the included base config
# and is repeated once per encoder layer — confirm against
# models/t5.1.1.base.gin.
encoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.TalkingHeadsSelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]
# Decoder layer list (binding applies only within the `decoder/` scope):
# talking-heads self-attention, then talking-heads encoder-decoder attention,
# then DenseReluDense, in that order.
# Presumably this overrides the layer list set by the included base config —
# confirm against models/t5.1.1.base.gin.
decoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.TalkingHeadsSelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.TalkingHeadsEncDecAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]

# Talking-heads attention hyperparameters.
#
# NOTE(review): only TalkingHeadsSelfAttention is bound in this file, although
# the decoder layer list also uses TalkingHeadsEncDecAttention. Presumably the
# enc-dec layer picks these values up through its parent class — confirm in
# mesh_tensorflow.transformer.transformer_layers.

# Macro referenced as %num_heads below so all three head dimensions share one
# size (the "24 talking heads" from the header comment).
num_heads = 24
# Key, softmax, and value heads are given as separately named dimensions, each
# sized from %num_heads.
TalkingHeadsSelfAttention.key_heads_dims = [("key_heads", %num_heads)]
TalkingHeadsSelfAttention.softmax_heads_dims = [("heads", %num_heads)]
TalkingHeadsSelfAttention.value_heads_dims = [("value_heads", %num_heads)]
# Per-head key and value sizes; 32 corresponds to "d_kv=32" in the header
# comment.
TalkingHeadsSelfAttention.key_size = 32
TalkingHeadsSelfAttention.value_size = 32
# %dropout_rate is not defined in this file — presumably it comes from the
# included base config; confirm.
TalkingHeadsSelfAttention.dropout_rate = %dropout_rate
TalkingHeadsSelfAttention.relative_attention_type = "bias_shared"
